#include <cuda.h>
#include <cuda_runtime.h>

#include <torch/extension.h>

#include <cub/util_type.cuh>
#include <cub/cub.cuh>

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <random>
#include <tuple>
#include <vector>

// Capacity of the per-thread error accumulator array in scale_grid_search_kernel;
// the n_grid argument passed to that kernel must not exceed this value.
#define MAX_N_GRID 16
// Block size used for every kernel launch in this file. It is also the
// compile-time block size handed to cub::BlockReduce, so kernels that use the
// reduction are expected to be launched with exactly this many threads.
#define THREADS_PER_BLOCK 512

// Mixes two integers into a 32-bit hash: each input is multiplied by a large
// constant and the two products are XOR-ed together.
//
// The products routinely exceed the range of int. Signed overflow is undefined
// behaviour in C++, so the multiplications are done on unsigned operands,
// where wrap-around modulo 2^32 is well defined. The resulting bit pattern is
// the one a wrapping signed multiply would produce.
//
// The return value can be any int, including negative values and INT_MIN, so
// callers that need a non-negative index should reinterpret it as unsigned
// rather than calling abs() on it.
__device__ __forceinline__ int hash(int x, int y)
{
    const unsigned int mixed_x = static_cast<unsigned int>(x) * 73856093u;
    const unsigned int mixed_y = static_cast<unsigned int>(y) * 19349663u;
    return static_cast<int>(mixed_x ^ mixed_y);
}

// Atomically replaces *addr with fmaxf(*addr, value) using a compare-and-swap
// retry loop, and returns the value that was stored at addr immediately before
// the successful swap.
//
// addr  : location holding the running maximum.
// value : candidate to merge into the maximum.
//
// The retry condition compares the raw 32-bit patterns, not the float values.
// A float comparison is wrong in two ways: NaN never compares equal to itself,
// so a NaN stored at addr can keep the loop spinning, and +0.0f == -0.0f is
// true even though atomicCAS (which compares bits) would have failed.
__device__ float atomicMaxFloat(float *addr, float value)
{
    int *const addr_as_int = reinterpret_cast<int *>(addr);
    int observed_bits = __float_as_int(*addr);
    int assumed_bits;
    do
    {
        assumed_bits = observed_bits;
        const float candidate = fmaxf(value, __int_as_float(assumed_bits));
        // atomicCAS returns the previous contents; the swap happened only if
        // that equals the bit pattern we assumed.
        observed_bits = atomicCAS(addr_as_int, assumed_bits, __float_as_int(candidate));
    } while (assumed_bits != observed_bits);
    return __int_as_float(observed_bits);
}

// Estimates max(|matrix|) from pseudo-randomly chosen tiles and max-merges the
// result into *global_max.
//
// matrix      : row-major M x N float matrix in device memory.
// tile_count  : number of tiles to sample.
// tile_length : number of consecutive elements of one row that form a tile.
// seed        : second input of the tile-selection hash.
// global_max  : single float, initialised by the caller (e.g. to 0) before launch.
//
// Launch layout: 1-D grid of 1-D blocks with THREADS_PER_BLOCK threads each and
// at least tile_count * tile_length threads in total; one thread handles one
// sampled element. No dynamic shared memory is needed.
//
// The matrix is viewed as M * ceil(N / tile_length) tiles. Every group of
// tile_length consecutive threads hashes its group number to one of those
// tiles, so samples are drawn from the whole matrix. Elements of a partial
// last tile that fall beyond column N contribute 0.
__global__ void sampled_max_abs_kernel(const float *__restrict__ matrix, int M, int N,
                                       int tile_count, int tile_length, int seed, float *global_max)
{
    using BlockReduce = cub::BlockReduce<float, THREADS_PER_BLOCK>;
    __shared__ typename BlockReduce::TempStorage reduce_storage;

    // 64-bit indices: neither the sample count nor the element offset may wrap.
    const long long sample_idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const long long total_samples = static_cast<long long>(tile_count) * tile_length;

    float val = 0.0f;

    // The guards on tile_length, M and N keep every division below well defined.
    if (tile_length > 0 && M > 0 && N > 0 && sample_idx < total_samples)
    {
        // Ceil-division so that N < tile_length yields one (partial) tile
        // instead of a zero divisor, and trailing columns remain reachable.
        const long long tiles_per_row = (static_cast<long long>(N) + tile_length - 1) / tile_length;
        const unsigned long long total_tiles = static_cast<unsigned long long>(tiles_per_row) * M;

        // Reinterpret the hash as unsigned instead of abs(): abs(INT_MIN) is
        // still negative and would produce a negative tile index.
        const unsigned int tile_hash =
            static_cast<unsigned int>(hash(static_cast<int>(sample_idx / tile_length), seed));
        const long long tile_idx = static_cast<long long>(tile_hash % total_tiles);

        const long long row = tile_idx / tiles_per_row;
        const long long col = (tile_idx % tiles_per_row) * tile_length + sample_idx % tile_length;

        if (row < M && col < N)
            val = fabsf(matrix[row * N + col]);
    }

    // All threads of the block take part in the reduction; out-of-range
    // threads contribute 0, which is neutral for a max over absolute values.
    const float block_max = BlockReduce(reduce_storage).Reduce(val, cub::Max(), blockDim.x);
    if (threadIdx.x == 0)
    {
        atomicMaxFloat(global_max, block_max);
    }
}

// Host wrapper: estimates the maximum absolute value of a 2-D float32 CUDA
// matrix by sampling tile_count tiles of tile_length elements each.
//
// matrix      : 2-D float32 CUDA tensor; it is made contiguous if necessary
//               because the kernel indexes it as a row-major buffer.
// tile_count  : number of tiles to sample (> 0).
// tile_length : elements per tile (> 0).
// seed        : seed forwarded to the tile-selection hash.
//
// Returns a one-element float32 tensor on matrix's device holding the sampled
// maximum (0 for an empty matrix). Throws c10::Error on invalid arguments or
// when the kernel launch or execution fails.
torch::Tensor sampled_max_abs(torch::Tensor matrix, int tile_count, int tile_length, int seed)
{
    TORCH_CHECK(matrix.is_cuda(), "sampled_max_abs: matrix must be a CUDA tensor");
    TORCH_CHECK(matrix.dim() == 2, "sampled_max_abs: matrix must be 2-D, got ", matrix.dim(), " dimension(s)");
    TORCH_CHECK(matrix.scalar_type() == torch::kFloat, "sampled_max_abs: matrix must be float32");
    TORCH_CHECK(tile_count > 0 && tile_length > 0,
                "sampled_max_abs: tile_count and tile_length must be positive");
    TORCH_CHECK(matrix.size(0) <= std::numeric_limits<int>::max() &&
                    matrix.size(1) <= std::numeric_limits<int>::max(),
                "sampled_max_abs: matrix dimensions do not fit in int");

    // The kernel works with int sample indices; leave room for the last,
    // partially used block so the index cannot wrap.
    const int64_t total_samples = static_cast<int64_t>(tile_count) * tile_length;
    TORCH_CHECK(total_samples <= std::numeric_limits<int>::max() - THREADS_PER_BLOCK,
                "sampled_max_abs: tile_count * tile_length is too large");

    auto global_max = torch::zeros({1}, torch::device(matrix.device()).dtype(torch::kFloat));
    if (matrix.numel() == 0)
        return global_max;  // nothing to sample

    matrix = matrix.contiguous();
    const int M = static_cast<int>(matrix.size(0));
    const int N = static_cast<int>(matrix.size(1));
    const int num_blocks = static_cast<int>((total_samples + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

    sampled_max_abs_kernel<<<num_blocks, THREADS_PER_BLOCK>>>(
        matrix.data_ptr<float>(), M, N,
        tile_count, tile_length, seed,
        global_max.data_ptr<float>());

    // Launch-configuration errors are reported by cudaGetLastError(); faults
    // raised while the kernel runs surface from the synchronising call.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "sampled_max_abs: kernel launch failed: ", cudaGetErrorString(launch_err));
    const cudaError_t sync_err = cudaDeviceSynchronize();
    TORCH_CHECK(sync_err == cudaSuccess,
                "sampled_max_abs: kernel execution failed: ", cudaGetErrorString(sync_err));

    return global_max;
}


// Accumulates, for each of n_grid candidate quantisation scales, the squared
// int8 round-trip error over pseudo-randomly sampled tiles of the matrix.
//
// matrix            : row-major M x N float matrix in device memory.
// n_grid            : number of candidate scales, 1 <= n_grid <= MAX_N_GRID.
//                     Candidate k (1-based) is absmax / 127 * k / n_grid.
// absmax            : upper bound used to derive the candidate scales.
// tile_count        : number of tiles to sample.
// tile_length       : number of consecutive elements of one row per tile.
// seed              : second input of the tile-selection hash.
// global_scale_errs : n_grid floats, zero-initialised by the caller; entry k-1
//                     receives the summed squared error of candidate k.
//
// Launch layout: 1-D grid of 1-D blocks with THREADS_PER_BLOCK threads each and
// at least tile_count * tile_length threads in total. No dynamic shared memory
// is needed.
//
// Tiles are selected from all M * ceil(N / tile_length) tiles of the matrix;
// positions of a partial last tile beyond column N contribute no error.
__global__ void scale_grid_search_kernel(const float *__restrict__ matrix, int M, int N, int n_grid, float absmax,
                                         int tile_count ,int tile_length, int seed, float *global_scale_errs )
{
    using BlockReduce = cub::BlockReduce<float, THREADS_PER_BLOCK>;
    __shared__ typename BlockReduce::TempStorage reduce_storage;

    // thread_errs has MAX_N_GRID slots; a larger n_grid would write past it.
    // The condition is identical for every thread, so returning here cannot
    // leave part of a block waiting at a barrier.
    if (n_grid < 1 || n_grid > MAX_N_GRID)
        return;

    float thread_errs[MAX_N_GRID];
    for (int i = 0; i < MAX_N_GRID; i++) {
        thread_errs[i] = 0.0f;
    }

    // 64-bit indices: neither the sample count nor the element offset may wrap.
    const long long sample_idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const long long total_samples = static_cast<long long>(tile_count) * tile_length;

    // The guards on tile_length, M and N keep every division below well defined.
    if (tile_length > 0 && M > 0 && N > 0 && sample_idx < total_samples)
    {
        // Ceil-division so that N < tile_length yields one (partial) tile
        // instead of a zero divisor, and trailing columns remain reachable.
        const long long tiles_per_row = (static_cast<long long>(N) + tile_length - 1) / tile_length;
        const unsigned long long total_tiles = static_cast<unsigned long long>(tiles_per_row) * M;

        // Reinterpret the hash as unsigned instead of abs(): abs(INT_MIN) is
        // still negative and would produce a negative tile index.
        const unsigned int tile_hash =
            static_cast<unsigned int>(hash(static_cast<int>(sample_idx / tile_length), seed));
        const long long tile_idx = static_cast<long long>(tile_hash % total_tiles);

        const long long row = tile_idx / tiles_per_row;
        const long long col = (tile_idx % tiles_per_row) * tile_length + sample_idx % tile_length;

        if (row < M && col < N){
            const float val = matrix[row * N + col];

            for (int scale_step = 1; scale_step <= n_grid; scale_step++)
            {
                // Same expression the host uses to turn the winning index back
                // into a scale, so both sides agree on the candidate values.
                const float scale = (absmax / 127.0f * scale_step / n_grid);
                // Quantise to the int8 range, then measure the round-trip error.
                float q_val = val / scale;
                q_val = fminf(127.0f, fmaxf(-128.0f, q_val));
                q_val = roundf(q_val);
                const float error = q_val * scale - val;
                thread_errs[scale_step - 1] += error * error;
            }
        }
    }

    for (int scale_idx = 0; scale_idx < n_grid; scale_idx++)
    {
        float const block_sum = BlockReduce(reduce_storage).Reduce(thread_errs[scale_idx], cub::Sum(), blockDim.x);
        // reduce_storage is reused by the next iteration; wait until every
        // thread is done with it.
        __syncthreads();

        if (threadIdx.x == 0)
        {
            atomicAdd(&global_scale_errs[scale_idx], block_sum);
        }
    }
}

// Host wrapper: evaluates n_grid candidate int8 quantisation scales on sampled
// tiles of a 2-D float32 CUDA matrix and picks the one with the smallest
// accumulated squared error.
//
// matrix      : 2-D float32 CUDA tensor; made contiguous if necessary because
//               the kernel indexes it as a row-major buffer.
// absmax      : upper bound used to derive the candidates; candidate k
//               (1-based) is absmax / 127 * k / n_grid.
// n_grid      : number of candidates, 1 <= n_grid <= MAX_N_GRID.
// tile_count  : number of tiles to sample (> 0).
// tile_length : elements per tile (> 0).
// seed        : seed forwarded to the tile-selection hash.
//
// Returns (errors, optimal_scale): a float32 tensor of n_grid accumulated
// squared errors and a 0-dim float32 tensor with the best scale, both on
// matrix's device. Throws c10::Error on invalid arguments or when the kernel
// launch or execution fails.
std::tuple<torch::Tensor, torch::Tensor> sampled_scale_grid_search(
        torch::Tensor matrix, float absmax, int n_grid, int tile_count, int tile_length, int seed)
{
    TORCH_CHECK(matrix.is_cuda(), "sampled_scale_grid_search: matrix must be a CUDA tensor");
    TORCH_CHECK(matrix.dim() == 2,
                "sampled_scale_grid_search: matrix must be 2-D, got ", matrix.dim(), " dimension(s)");
    TORCH_CHECK(matrix.scalar_type() == torch::kFloat, "sampled_scale_grid_search: matrix must be float32");
    // The kernel keeps one accumulator per candidate in a MAX_N_GRID-sized
    // per-thread array, so a larger n_grid would overrun it.
    TORCH_CHECK(n_grid >= 1 && n_grid <= MAX_N_GRID,
                "sampled_scale_grid_search: n_grid must be in [1, ", MAX_N_GRID, "], got ", n_grid);
    TORCH_CHECK(tile_count > 0 && tile_length > 0,
                "sampled_scale_grid_search: tile_count and tile_length must be positive");
    TORCH_CHECK(matrix.size(0) <= std::numeric_limits<int>::max() &&
                    matrix.size(1) <= std::numeric_limits<int>::max(),
                "sampled_scale_grid_search: matrix dimensions do not fit in int");

    // The kernel works with int sample indices; leave room for the last,
    // partially used block so the index cannot wrap.
    const int64_t total_samples = static_cast<int64_t>(tile_count) * tile_length;
    TORCH_CHECK(total_samples <= std::numeric_limits<int>::max() - THREADS_PER_BLOCK,
                "sampled_scale_grid_search: tile_count * tile_length is too large");

    auto global_scale_errs = torch::zeros({n_grid}, torch::device(matrix.device()).dtype(torch::kFloat));

    // An empty matrix has nothing to sample; the errors then stay at zero.
    if (matrix.numel() > 0)
    {
        matrix = matrix.contiguous();
        const int M = static_cast<int>(matrix.size(0));
        const int N = static_cast<int>(matrix.size(1));
        const int num_blocks = static_cast<int>((total_samples + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

        // The kernel only uses statically declared shared memory, so no
        // dynamic shared memory is requested.
        scale_grid_search_kernel<<<num_blocks, THREADS_PER_BLOCK>>>(
            matrix.data_ptr<float>(), M, N, n_grid, absmax,
            tile_count, tile_length, seed,
            global_scale_errs.data_ptr<float>());

        // Launch-configuration errors are reported by cudaGetLastError();
        // faults raised while the kernel runs surface from the sync.
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess,
                    "sampled_scale_grid_search: kernel launch failed: ", cudaGetErrorString(launch_err));
        const cudaError_t sync_err = cudaDeviceSynchronize();
        TORCH_CHECK(sync_err == cudaSuccess,
                    "sampled_scale_grid_search: kernel execution failed: ", cudaGetErrorString(sync_err));
    }

    // argmin yields an int64 index; it is below n_grid and therefore fits in int.
    const int min_scale_idx = static_cast<int>(global_scale_errs.argmin().item<int64_t>());
    const float optimal_scale = (absmax / 127.0f) * (min_scale_idx + 1) / n_grid;
    at::Tensor optimal_scale_tensor = torch::tensor(optimal_scale, torch::dtype(torch::kFloat32).device(matrix.device()));
    return std::make_tuple(global_scale_errs, optimal_scale_tensor);
}


// Converts a float to int8 with a single PTX instruction.
//
// `cvt.rni.sat.s8.f32` rounds to the nearest integer (per the PTX ISA, `.rni`
// rounds ties to even) and `.sat` clamps the result to the signed 8-bit range.
// The instruction writes a 32-bit register; the function returns the int8_t
// obtained by reinterpreting that register's storage.
// NOTE(review): taking the first byte of `dst` as the result assumes a
// little-endian layout — true for the CUDA targets this is built for, confirm
// if ported.
inline __device__ int8_t float_to_int8_rn(float x)
{
// Disabled ROCm fallback, kept for reference:
//     #ifdef USE_ROCM
//         static constexpr auto i8_min = static_cast<float>(std::numeric_limits<int8_t>::min());
//         static constexpr auto i8_max = static_cast<float>(std::numeric_limits<int8_t>::max());
//
//         // Use nearbyint for rounding to nearest integer
//         float dst = nearbyintf(x);
//
//         // Saturate the result to int8 range
//         dst = fminf(fmaxf(dst, i8_min), i8_max);
//         return static_cast<int8_t>(dst);
//     #else
        // CUDA path
        uint32_t dst;
        asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
        return reinterpret_cast<const int8_t &>(dst);
//     #endif
}

// Element-wise int8 quantisation: output[i] = float_to_int8_rn(input[i] / scale).
//
// input        : num_elements floats in device memory.
// output       : num_elements int8 values in device memory.
// scale        : divisor applied to every element before conversion.
// num_elements : number of elements to process.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per element; surplus
// threads in the last block exit without touching memory.
__global__ void quantize_tensor_kernel(
    const float* __restrict__ input, int8_t* __restrict__ output,
    float scale, int num_elements)
{
    // Observed to be slower than the PyTorch implementation; cause not yet understood.
    const int element = threadIdx.x + blockDim.x * blockIdx.x;

    if (element >= num_elements)
        return;

    const float scaled = input[element] / scale;
    output[element] = float_to_int8_rn(scaled);
}

// Host wrapper: quantises a float32 CUDA tensor to int8 by dividing every
// element by `scale`, rounding to nearest and saturating to the int8 range.
//
// input_tensor : float32 CUDA tensor of any shape. The kernel walks the
//                underlying storage linearly, so a tensor whose storage is not
//                dense is first copied to a contiguous one.
// scale        : quantisation step; each element is divided by it.
//
// Returns an int8 tensor with the shape of input_tensor on the same device.
// Throws c10::Error on invalid arguments or when the kernel launch or
// execution fails.
//
// Note: observed to be slower than the PyTorch implementation; cause not yet
// understood.
torch::Tensor quantize_tensor(torch::Tensor input_tensor, float scale)
{
    TORCH_CHECK(input_tensor.is_cuda(), "quantize_tensor: input_tensor must be a CUDA tensor");
    TORCH_CHECK(input_tensor.scalar_type() == torch::kFloat, "quantize_tensor: input_tensor must be float32");
    // The kernel indexes with int; leave room for the last, partially used
    // block so the index cannot wrap.
    TORCH_CHECK(input_tensor.numel() <= std::numeric_limits<int>::max() - THREADS_PER_BLOCK,
                "quantize_tensor: input_tensor has too many elements (", input_tensor.numel(), ")");

    // For a dense tensor, empty_like reproduces the strides, so input and
    // output share one linear layout. A non-dense view has gaps in storage and
    // must be compacted before it can be read linearly.
    if (!input_tensor.is_non_overlapping_and_dense())
        input_tensor = input_tensor.contiguous();

    const int num_elements = static_cast<int>(input_tensor.numel());
    auto output_tensor = torch::empty_like(input_tensor, torch::dtype(torch::kInt8));

    // A grid with zero blocks is an invalid launch configuration.
    if (num_elements == 0)
        return output_tensor;

    const int num_blocks = (num_elements + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;

    quantize_tensor_kernel<<<num_blocks, THREADS_PER_BLOCK>>>(
        input_tensor.data_ptr<float>(), output_tensor.data_ptr<int8_t>(),
        scale, num_elements);

    // Launch-configuration errors are reported by cudaGetLastError(); faults
    // raised while the kernel runs surface from the synchronising call.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "CUDA Error in quantize_tensor_kernel: ", cudaGetErrorString(launch_err));
    const cudaError_t sync_err = cudaDeviceSynchronize();
    TORCH_CHECK(sync_err == cudaSuccess,
                "CUDA Error in quantize_tensor_kernel: ", cudaGetErrorString(sync_err));

    return output_tensor;
}


// Quantises a 2-D float32 CUDA tensor to int8 with a scale chosen by sampled
// grid search.
//
// Steps:
//   1. estimate max|x| from sampled tiles (sampled_max_abs);
//   2. widen that estimate by 10% and evaluate n_grid candidate scales on a
//      second set of sampled tiles (sampled_scale_grid_search, seed + 1);
//   3. quantise the whole tensor with the best candidate (quantize_tensor).
//
// input_tensor : 2-D float32 CUDA tensor; made contiguous if necessary because
//                the sampling kernels index it as a row-major buffer.
// n_grid       : number of candidate scales, 1 <= n_grid <= MAX_N_GRID.
// sampling     : fraction of the elements to sample; at least one tile of 32
//                elements is always used.
// seed         : seed for tile selection.
//
// Returns (quantised int8 tensor, 0-dim float32 tensor holding the scale).
// Throws c10::Error on invalid arguments.
std::tuple<torch::Tensor, torch::Tensor> grid_search_quant_int8(
    torch::Tensor input_tensor, int n_grid, float sampling, int seed)
{
    TORCH_CHECK(input_tensor.is_cuda(), "grid_search_quant_int8: input_tensor must be a CUDA tensor");
    TORCH_CHECK(input_tensor.dim() == 2,
                "grid_search_quant_int8: input_tensor must be 2-D, got ", input_tensor.dim(), " dimension(s)");
    TORCH_CHECK(input_tensor.scalar_type() == torch::kFloat,
                "grid_search_quant_int8: input_tensor must be float32");
    // The grid-search kernel stores one accumulator per candidate in a
    // MAX_N_GRID-sized per-thread array.
    TORCH_CHECK(n_grid >= 1 && n_grid <= MAX_N_GRID,
                "grid_search_quant_int8: n_grid must be in [1, ", MAX_N_GRID, "], got ", n_grid);

    // Every stage reads the tensor through its raw data pointer.
    input_tensor = input_tensor.contiguous();

    const int tile_length = 32;
    const int tile_count = max(1, (int)((input_tensor.numel() * sampling) / tile_length));

    at::Tensor absmax = sampled_max_abs(input_tensor, tile_count, tile_length, seed);

    // The sampled maximum can miss the true maximum, hence the 10% head-room.
    // A different seed makes the error estimate use other tiles than the
    // abs-max estimate.
    auto [scale_errs, optimal_scale] = sampled_scale_grid_search(
            input_tensor, absmax.item<float>()*1.1f, n_grid, tile_count, tile_length, seed + 1);

    auto output_tensor = quantize_tensor(input_tensor, optimal_scale.item<float>());
    return std::make_tuple(output_tensor, optimal_scale);
}



// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// {
//     m.def("sampled_max_abs", &sampled_max_abs, "Find the max absolute value using sampling",
//           py::arg("matrix"), py::arg("tile_count"), py::arg("tile_length"), py::arg("seed"));
//     m.def("sampled_scale_grid_search", &sampled_scale_grid_search, "Find the optimal scale using grid search",
//           py::arg("matrix"), py::arg("absmax"), py::arg("n_grid"), py::arg("tile_count"), py::arg("tile_length"), py::arg("seed"));
//     m.def("quantize_tensor", &quantize_tensor, "Quantize tensor to int8 (CUDA)",
//           py::arg("input_tensor"), py::arg("scale"));
//     m.def("grid_search_quant_int8", &grid_search_quant_int8, "Quantize tensor to int8 using grid search",
//             py::arg("input_tensor"), py::arg("n_grid")=10, py::arg("sampling")=0.1, py::arg("seed")=42);
// }

